import math

import mindspore.nn as nn
from mindspore.ops import operations as P
import numpy as np
from mindspore import Tensor

# Public API of this module. NOTE(review): Swish is intentionally omitted —
# presumably it is only reached indirectly via Activation('swish'); confirm
# before exposing it.
__all__ = ['Activation', 'Conv2d', 'ConvBNReLU', 'Conv2dTranspose', 'UpSample', ]


class Swish(nn.Cell):
    """Swish activation: ``x * sigmoid(x)``, built from primitive ops."""

    def __init__(self):
        super(Swish, self).__init__()
        # Primitive ops are created once and reused in construct().
        self.sigmoid = P.Sigmoid()
        self.mul = P.Mul()

    def construct(self, x):
        gate = self.sigmoid(x)
        return self.mul(x, gate)


class Activation(nn.Cell):
    """
    Activation definition: wraps a named activation function as a Cell.

    Args:
        act_func(string): activation name; one of 'relu', 'relu6',
            'hsigmoid'/'hard_sigmoid', 'hswish'/'hard_swish', 'swish'.

    Returns:
         Tensor, output tensor.

    Raises:
        NotImplementedError: if `act_func` is not a supported name.
    """

    def __init__(self, act_func):
        super(Activation, self).__init__()
        # Dispatch table keeps every supported name (and its aliases) in one place.
        factories = {
            'relu': nn.ReLU,
            'relu6': nn.ReLU6,
            'hsigmoid': nn.HSigmoid,
            'hard_sigmoid': nn.HSigmoid,
            'hswish': nn.HSwish,
            'hard_swish': nn.HSwish,
            'swish': Swish,
        }
        if act_func not in factories:
            # Keep NotImplementedError for backward compatibility with callers,
            # but include the offending name so failures are diagnosable.
            raise NotImplementedError(
                'Unsupported activation function: {}'.format(act_func))
        self.act = factories[act_func]()

    def construct(self, x):
        return self.act(x)


def Conv2d(in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
    """Build an ``nn.Conv2d`` in explicit 'pad' mode (torch-like padding semantics)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size,
                     stride=stride, pad_mode='pad', padding=padding,
                     dilation=dilation, group=groups, has_bias=bias)


def Conv2dTranspose(in_planes, out_planes, kernel_size=4, stride=2, bias=True):
    """Build an ``nn.Conv2dTranspose`` with the framework's default pad mode."""
    return nn.Conv2dTranspose(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride, has_bias=bias)


class UpSample(nn.Cell):
    """Fixed (non-learned) x`ratio` upsampling via a transposed conv whose
    weights are initialized to a bilinear-interpolation kernel."""

    def __init__(self, channels, ratio=2, use_on_ascend=True):
        super().__init__()
        if ratio == 1:
            # Identity: nothing to resize.
            self.resize = None
            return
        # group=1 for Ascend; group=channels for GPU
        self.resize = nn.Conv2dTranspose(channels, channels, ratio * 2, ratio,
                                         group=1 if use_on_ascend else channels, has_bias=False)

        shape = self.resize.weight.shape
        f = math.ceil(shape[2] / 2)
        center = (2 * f - 1 - f % 2) / (2. * f)
        # The 2-D bilinear kernel is separable: the outer product of two
        # 1-D triangular profiles 1 - |i/f - center|.
        row = 1.0 - np.abs(np.arange(shape[2]) / f - center)
        col = 1.0 - np.abs(np.arange(shape[3]) / f - center)
        kernel = np.outer(row, col)
        weight = np.zeros(shape)
        weight[0, 0] = kernel
        for ch in range(1, shape[0]):
            # Depthwise weights (shape[1] == 1) all sit at input index 0;
            # dense weights place the kernel on the channel diagonal.
            weight[ch, ch if shape[1] != 1 else 0] = kernel
        self.resize.weight.set_data(Tensor(weight.astype(np.float32)))

    def construct(self, x):
        # Pass-through when ratio == 1.
        return x if self.resize is None else self.resize(x)


class ConvBNReLU(nn.Cell):
    """Conv2d ('pad' mode) -> BatchNorm2d -> optional activation.

    Pass ``act='none'`` to skip the activation entirely.
    """

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, group=1, act='relu', bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, 'pad', padding, group=group, has_bias=bias)
        self.bn = nn.BatchNorm2d(out_planes)
        if act == 'none':
            self.act = None
        else:
            self.act = Activation(act)

    def construct(self, x):
        out = self.bn(self.conv(x))
        return out if self.act is None else self.act(out)


def _test():
    """Compare UpSample(1, 2) against torch bilinear interpolation on a ramp."""
    import torch
    import mindspore.context as context
    context.set_context(device_target="GPU")
    data = np.arange(36).reshape((1, 1, 6, 6)).astype(np.float32)
    ref = torch.nn.functional.interpolate(torch.from_numpy(data), scale_factor=2, mode='bilinear', align_corners=False)
    out = UpSample(1, 2)(Tensor(data))
    ref = ref.numpy()
    out = out.asnumpy()
    np.set_printoptions(precision=2, suppress=True, linewidth=150)
    # Shapes, the torch reference, elementwise difference, and max abs error.
    print(ref.shape, out.shape)
    print(ref)
    print(ref - out)
    print(np.abs(ref - out).max())


def _test2():
    """Cross-check UpSample against its torch counterpart with converted weights."""
    import torch
    import mindspore.context as context
    from extension.layers import UpSample as UpSample_torch
    from src.conver_torch_model_to_mindspore import convert
    from mindspore.train.serialization import load_param_into_net
    context.set_context(device_target="GPU")

    ratio = 6
    data = np.random.rand(2, 32, 64, 64).astype(np.float32)
    ms_net = UpSample(32, ratio, use_on_ascend=False)
    out = ms_net(Tensor(data))
    print(out.shape)
    ms_net.compile_and_run(Tensor(data))

    torch_net = UpSample_torch(32, ratio)
    torch_net.eval()
    # Load the torch weights into the MindSpore net, then compare outputs.
    load_param_into_net(ms_net, convert(torch_net.state_dict()))
    ref = torch_net(torch.from_numpy(data))
    out = ms_net(Tensor(data))
    print(np.abs(out.asnumpy() - ref.cpu().detach().numpy()).max())


# Script entry point: runs the torch weight-conversion cross-check by default.
if __name__ == '__main__':
    _test2()
